/*
 * demo_HSMM_batchLQR01.cpp
 *
 * HSMM combined with batch LQR.
 *
 * If this code is useful for your research, please cite the related publication:
 * @incollection{Calinon19chapter,
 * 	 author="Calinon, S. and Lee, D.",
 * 	 title="Learning Control",
 * 	 booktitle="Humanoid Robotics: a Reference",
 * 	 publisher="Springer",
 * 	 editor="Vadakkepat, P. and Goswami, A.", 
 * 	 year="2019",
 * 	 doi="10.1007/978-94-007-7194-9_68-1",
 * 	 pages="1--52"
 * }
 *
 * Authors: Sylvain Calinon, Philip Abbet
 *
 * This file is part of PbDlib, https://www.idiap.ch/software/pbdlib/
 * 
 * PbDlib is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 * 
 * PbDlib is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with PbDlib. If not, see <https://www.gnu.org/licenses/>.
 */

#include <stdio.h>
#include <armadillo>

#include <gfx2.h>
#include <gfx_ui.h>
#include <GLFW/glfw3.h>
#include <imgui.h>
#include <imgui_impl_glfw_gl2.h>

using namespace arma;


/***************************** ALGORITHM SECTION *****************************/

// Convenience aliases for ordered collections of Armadillo column vectors
// and matrices
using vector_list_t = std::vector<vec>;
using matrix_list_t = std::vector<mat>;


//-----------------------------------------------------------------------------
// Settings used by the learning and reproduction steps. All the fields are
// assigned in main(): nb_data, nb_states and rfactor are refreshed from the
// UI values, whereas dt is only set once at startup.
//-----------------------------------------------------------------------------
struct parameters_t {
	// Number of datapoints in a trajectory
	int nb_data;

	// Number of hidden states in the HSMM
	int nb_states;

	// Control cost in LQR
	double rfactor;

	// Time step duration
	double dt;
};


//-----------------------------------------------------------------------------
// Model trained using the algorithm (filled by learn() and its helpers)
//-----------------------------------------------------------------------------
struct model_t {
	// Parameters used to train the model
	parameters_t parameters;

	// Center of each state (one vector per state)
	vector_list_t mu;

	// Covariance matrix of each state (one matrix per state)
	matrix_list_t sigma;

	// State transition probabilities (nb_states x nb_states)
	mat transitions;

	// Initial state probabilities (nb_states elements)
	vec states_priors;

	// State responsibilities computed by the EM procedure: one row per state,
	// one column per datapoint of the training set, each row being normalised
	// by its sum
	mat H;
};


//-----------------------------------------------------------------------------
// Evaluate the density of a multivariate Gaussian at a set of datapoints.
//
// Inputs:
//   - Data:  D x N array, one datapoint of D dimensions per column.
//   - Mu:    D x 1 vector, the center of the Gaussian.
//   - Sigma: D x D array, the covariance matrix of the Gaussian.
//
// Output:
//   - Column vector of N elements holding the likelihood of each datapoint.
//-----------------------------------------------------------------------------
arma::vec gaussPDF(mat Data, colvec Mu, mat Sigma) {

	const int nb_var = Data.n_rows;
	const int nb_points = Data.n_cols;

	// Centered datapoints, stored one per row
	const mat centered = Data.t() - repmat(Mu.t(), nb_points, 1);

	// Squared Mahalanobis distance of each datapoint to the center
	const vec distances = sum((centered * inv(Sigma)) % centered, 1);

	// Normalisation term (DBL_MIN keeps the denominator strictly positive
	// when the determinant vanishes)
	const double normalisation = sqrt(pow((2 * datum::pi), nb_var) * det(Sigma) + DBL_MIN);

	return exp(-0.5 * distances) / normalisation;
}


//-----------------------------------------------------------------------------
// Initialization of Gaussian Mixture Model (GMM) parameters by clustering 
// an ordered dataset into equal bins
//-----------------------------------------------------------------------------
void init_GMM_kbins(const matrix_list_t& data, model_t &model) {

	// Regularization term to avoid numerical instability
	const double diag_reg_fact = 1e-4;

	model.mu.clear();
	model.sigma.clear();

	// Delimit the cluster bins
	uvec timing_sep = conv_to<uvec>::from(
		round(linspace<vec>(0, model.parameters.nb_data, model.parameters.nb_states + 1))
	);

	// Compute statistics for each bin
	for (int i = 0; i < model.parameters.nb_states; ++i) {
		span id(timing_sep(i), timing_sep(i + 1) - 1);
		int nb_ids = id.b - id.a + 1;

		mat values(data[0].n_rows, data.size() * nb_ids);
		for (int n = 0; n < data.size(); ++n)
			values(span::all, span(n * nb_ids, (n + 1) * nb_ids - 1)) = data[n](span::all, id);

		model.mu.push_back(mean(values, 1));
		model.sigma.push_back(cov(values.t()) + eye(values.n_rows, values.n_rows) * diag_reg_fact);
	}
}


//-----------------------------------------------------------------------------
// Estimation of HMM parameters with an EM algorithm
//
// Inputs:
//   - data:  list of demonstrations; each one is copied into a block of
//            'model.parameters.nb_data' columns of the concatenated dataset,
//            so every demonstration must have that number of columns.
//   - model: must already contain initial values for 'mu', 'sigma',
//            'transitions' and 'states_priors'. On return, those fields hold
//            the re-estimated values and 'H' holds the row-normalised state
//            responsibilities of the last E-step.
//
// The loop runs for at most 'nb_max_steps' iterations and stops earlier once
// 'iter >= nb_min_steps' and the increase of the average log-likelihood falls
// below 'max_diff_log_likelihood'.
//-----------------------------------------------------------------------------
void EM_HMM(const matrix_list_t& data, model_t &model) {

	const int nb_var = data[0].n_rows;
	const int nb_samples = data.size();
	const int nb_data = nb_samples * model.parameters.nb_data;	// Total number of datapoints

	const int nb_max_steps = 50;			// Maximum number of iterations allowed
	const int nb_min_steps = 5;				// Minimum number of iterations allowed
	const double max_diff_log_likelihood = 1e-4;	// Likelihood increase threshold
													// to stop the algorithm
	const double diag_reg_fact			 = 1e-4;	//Regularization term

	// Concatenate all the demonstrations column-wise (used by the M-step)
	mat all_data(nb_var, nb_data);
	for (int i = 0; i < nb_samples; ++i)
		all_data(span::all, span(i * model.parameters.nb_data, (i + 1) * model.parameters.nb_data - 1)) = data[i];

	// Average log-likelihood obtained at each iteration
	std::vector<double> log_likelihoods;

	for (int iter = 0; iter < nb_max_steps; ++iter) {

		// Accumulators filled by the E-step, with the demonstrations stored
		// one after the other along the column (or slice) dimension:
		//   - c:          scaling factors of the forward variable, per demonstration
		//   - GAMMA:      state responsibilities for every datapoint
		//   - GAMMA_INIT: state responsibilities of the first datapoint of each demonstration
		//   - GAMMA_TRK:  state responsibilities without the last datapoint of each demonstration
		//   - ZETA:       (i, j) transition terms for each pair of consecutive datapoints
		vector_list_t c;
		mat GAMMA = zeros(model.parameters.nb_states, nb_samples * model.parameters.nb_data);
		mat GAMMA_INIT = zeros(model.parameters.nb_states, nb_samples);
		mat GAMMA_TRK = zeros(model.parameters.nb_states, nb_samples * (model.parameters.nb_data - 1));
		cube ZETA = zeros(model.parameters.nb_states, model.parameters.nb_states,
						  nb_samples * (model.parameters.nb_data - 1));

		// E-step
		for (int n = 0; n < nb_samples; ++n) {

			// Emission probabilities (one row per state, one column per datapoint)
			mat B(model.parameters.nb_states, model.parameters.nb_data);
			for (int i = 0; i < model.parameters.nb_states; ++i)
				B(i, span::all) = gaussPDF(data[n], model.mu[i], model.sigma[i]).t();

			// Forward variable ALPHA (rescaled, to avoid underflow issues)
			mat ALPHA(model.parameters.nb_states, model.parameters.nb_data);
			vec c_(model.parameters.nb_data);

			// Initialisation from the state priors; DBL_MIN guards against a
			// division by zero when all the emission probabilities vanish
			ALPHA(span::all, 0) = model.states_priors % B(span::all, 0);
			c_(0) = 1.0 / sum(ALPHA(span::all, 0) + DBL_MIN);
			ALPHA(span::all, 0) = ALPHA(span::all, 0) * c_(0);

			// Forward recursion, each column being rescaled by c_(t)
			for (int t = 1; t < model.parameters.nb_data; ++t) {
				ALPHA(span::all, t) = (ALPHA(span::all, t - 1).t() * model.transitions).t() % B(span::all, t);
				c_(t) = 1.0 / sum(ALPHA(span::all, t) + DBL_MIN);
				ALPHA(span::all, t) = ALPHA(span::all, t) * c_(t);
			}

			c.push_back(c_);

			// Backward variable BETA (rescaled)
			mat BETA(model.parameters.nb_states, model.parameters.nb_data);

			BETA(span::all, model.parameters.nb_data - 1) =
				ones(model.parameters.nb_states, 1) * c_(model.parameters.nb_data - 1);

			// Backward recursion, reusing the scaling factors of ALPHA and
			// clamping the values to DBL_MAX
			for (int t = model.parameters.nb_data - 2; t >= 0; --t) {
				BETA(span::all, t) = model.transitions * (BETA(span::all, t + 1) % B(span::all, t + 1));
				BETA(span::all, t) = min(BETA(span::all, t) * c_(t), ones(BETA.n_rows) * DBL_MAX);
			}

			// Intermediate variable GAMMA (ALPHA % BETA, normalised per datapoint)
			mat GAMMA_ = (ALPHA % BETA) / repmat(sum(ALPHA % BETA) + DBL_MIN, model.parameters.nb_states, 1);

			// Store the results of this demonstration in the accumulators
			GAMMA(span::all, span(n * model.parameters.nb_data, (n + 1) * model.parameters.nb_data - 1)) = GAMMA_;
			GAMMA_INIT(span::all, n) = GAMMA_(span::all, 0);
			GAMMA_TRK(span::all, span(n * (model.parameters.nb_data - 1), (n + 1) * (model.parameters.nb_data - 1) - 1)) =
				GAMMA_(span::all, span(0, model.parameters.nb_data - 2));

			// Intermediate variable ZETA (fast version, by considering scaling factor)
			for (int i = 0; i < model.parameters.nb_states; ++i) {
				for (int j = 0; j < model.parameters.nb_states; ++j) {
					ZETA(span(i), span(j), span(n * (model.parameters.nb_data - 1),
												(n + 1) * (model.parameters.nb_data - 1) - 1)) =
						model.transitions(i, j) *
						(ALPHA(i, span(0, model.parameters.nb_data - 2)) %
						 B(j, span(1, model.parameters.nb_data - 1)) %
						 BETA(j, span(1, model.parameters.nb_data - 1))
						);
				}
			}
		}

		// Responsibilities normalised so that each row (state) sums to one;
		// used as weights in the M-step
		model.H = GAMMA / repmat(sum(GAMMA, 1) + DBL_MIN, 1, GAMMA.n_cols);

		// M-step
		for (int i = 0; i < model.parameters.nb_states; ++i) {

			// Update the centers
			model.mu[i] = all_data * model.H(i, span::all).t();

			// Update the covariance matrices
			mat data_tmp = all_data - repmat(model.mu[i], 1, nb_data);
			model.sigma[i] = data_tmp * diagmat(model.H(i, span::all)) * data_tmp.t() +	// Eq. (54) Rabiner
							 eye(nb_var, nb_var) * diag_reg_fact;	// Regularization term
		}

		// Update initial state probability vector
		model.states_priors = mean(GAMMA_INIT, 1);

		// Update transition probabilities
		model.transitions = mat(sum(ZETA, 2)) / repmat(sum(GAMMA_TRK, 1) + DBL_MIN, 1, model.parameters.nb_states);

		// Compute the average log-likelihood through the ALPHA scaling factors
		log_likelihoods.push_back(0.0);
		for (int n = 0; n < nb_samples; ++n)
			log_likelihoods[iter] = log_likelihoods[iter] - sum(log(c[n]));

		log_likelihoods[iter] = log_likelihoods[iter] / nb_samples;

		// Stop the algorithm if EM converged
		if (iter >= nb_min_steps) {
			if (log_likelihoods[iter] - log_likelihoods[iter - 1] < max_diff_log_likelihood)
				break;
		}
	}
}


//-----------------------------------------------------------------------------
// Learn the model from the demonstrations
//-----------------------------------------------------------------------------
void learn(const matrix_list_t& data, model_t &model) {

	init_GMM_kbins(data, model);

	// Left-right model initialization
	model.transitions = zeros(model.parameters.nb_states, model.parameters.nb_states);

	for (int i = 0; i < model.parameters.nb_states - 1; ++i) {
		model.transitions(i, i) = 1.0 - (double) model.parameters.nb_states / model.parameters.nb_data;
		model.transitions(i, i + 1) = (double) model.parameters.nb_states / model.parameters.nb_data;
	}

	model.transitions(model.parameters.nb_states - 1, model.parameters.nb_states - 1) = 1.0;

	model.states_priors = zeros(model.parameters.nb_states);
	model.states_priors(0) = 1.0;

	EM_HMM(data, model);

	// Removal of self-transition (for HSMM representation) and normalization
	model.transitions = model.transitions - diagmat(model.transitions) +
						eye(model.parameters.nb_states, model.parameters.nb_states) * DBL_MIN;

	model.transitions(model.parameters.nb_states - 1, model.parameters.nb_states - 1) = 1.0;

	model.transitions = model.transitions / repmat(sum(model.transitions, 1), 1, model.parameters.nb_states);
}


//-----------------------------------------------------------------------------
// Compute a reproduction using batch LQR
//
// The function:
//   1. estimates a log-normal duration model for each state from the most
//      likely state of each training datapoint (model.H),
//   2. rebuilds a sequence of state probabilities with the HSMM forward
//      recursion (priors, transitions and duration probabilities),
//   3. converts the most likely state sequence into a stepwise reference
//      (MuQ, SigmaQ) and tracks it with a batch LQR on a double integrator.
//
// Inputs:
//   - model:       trained model ('mu', 'sigma', 'transitions',
//                  'states_priors', 'H' and 'parameters' are read)
//   - start_point: initial position (nb_var_pos elements); the initial
//                  velocity is set to zero
//
// Output:
//   - nb_var_pos x nb_data matrix containing the reproduced positions
//-----------------------------------------------------------------------------
mat compute_LQR(const model_t& model, const vec& start_point) {

	// Minimum variance of state duration (regularization term)
	const double min_sigma_Pd = 2e-1;

	// Number of maximum duration step to consider in the HSMM (2.5 is a safety factor)
	const int nbD = round(2.5f * (float) model.parameters.nb_data / model.parameters.nb_states);

	// Dimension of position data (here: x1,x2)
	const int nb_var_pos = 2;


	// Post-estimation of the state duration from data (for HSMM representation)
	//--------------------------------------------------------------------------

	// st[i] collects the logarithm of the length of every run of consecutive
	// datapoints assigned to state i
	std::vector< std::vector<double> > st(model.parameters.nb_states);

	// Most likely state of each training datapoint
	urowvec hmax = index_max(model.H, 0);

	unsigned int current_state = hmax(0);
	unsigned int count = 1;

	// The first datapoint is already accounted for by the initial value of
	// 'count', so the scan starts at the second one (starting at 0 would make
	// the first run one sample too long)
	for (uword t = 1; t < hmax.n_elem; ++t) {
		if (hmax(t) == current_state) {
			++count;
		} else {
			st[current_state].push_back(log(static_cast<double>(count)));
			count = 1;
			current_state = hmax(t);
		}
	}
	st[current_state].push_back(log(static_cast<double>(count)));

	// Compute state duration as Gaussian distribution (in the log domain)
	vector_list_t Mu_Pd;
	matrix_list_t Sigma_Pd;
	for (int i = 0; i < model.parameters.nb_states; ++i) {
		if (!st[i].empty()) {
			vec st_(st[i].size());
			for (size_t j = 0; j < st[i].size(); ++j)
				st_(j) = st[i][j];

			Mu_Pd.push_back(vec({ mean(st_) }));
			Sigma_Pd.push_back(cov(st_) + min_sigma_Pd);
		} else {
			// State never selected as the most likely one: fall back to a
			// zero log-duration with the minimum variance
			Mu_Pd.push_back(vec({ 0.0 }));
			Sigma_Pd.push_back(cov(vec({ 0.0 })) + min_sigma_Pd);
		}
	}


	// Reconstruction of states probability sequence
	//----------------------------------------------

	// Precomputation of duration probabilities: Pd(i, k) is the density of
	// the duration model of state i evaluated at log(k), for k = 0..nbD
	mat Pd(model.parameters.nb_states, nbD + 1);

	vec logs = log(linspace<vec>(0, nbD, nbD + 1));

	for (int i = 0; i < model.parameters.nb_states; ++i)
		Pd(i, span::all) = gaussPDF(logs.t(), Mu_Pd[i], Sigma_Pd[i]).t();

	// Reconstruction of states sequence
	// NOTE(review): column 0 of Pd corresponds to log(0) = -inf, and the
	// lookups below use Pd(i, t) and Pd(i, d - 1) - confirm that this offset
	// between the column index and the duration is the intended one
	mat h = zeros(model.parameters.nb_states, model.parameters.nb_data);

	for (int t = 0; t < model.parameters.nb_data; ++t) {
		for (int i = 0; i < model.parameters.nb_states; ++i) {
			// Contribution of sequences that are still in their initial state
			if (t < nbD)
				h(i, t) = model.states_priors(i) * Pd(i, t);

			// Contribution of sequences that entered state i 'd' steps ago
			for (int d = 1; d <= std::min(t, nbD); ++d)
				h(i, t) = h(i, t) + mat(h(span::all, t - d).t() * model.transitions(span::all, i) * Pd(i, d - 1))(0, 0);
		}
	}

	// Normalise each time step (DBL_MIN avoids a division by zero)
	h = h / repmat(sum(h, 0) + DBL_MIN, model.parameters.nb_states, 1);


	// Batch LQR reproduction
	//-----------------------

	// Dynamical System settings (discrete version), see Eq. (33)
	mat A = kron(mat({{ 1.0, model.parameters.dt }, { 0.0, 1.0 }}), eye(nb_var_pos, nb_var_pos));
	mat B = kron(mat({{ 0.0, model.parameters.dt }}).t(), eye(nb_var_pos, nb_var_pos));
	mat C = kron(mat({{ 1.0, 0.0 }}), eye(nb_var_pos, nb_var_pos));

	// Control cost matrix
	mat R = eye(nb_var_pos, nb_var_pos) * model.parameters.rfactor;
	R = kron(eye(model.parameters.nb_data - 1, model.parameters.nb_data - 1), R);

	// Build CSx and CSu matrices for batch LQR, see Eq. (35)
	mat CSu = zeros(nb_var_pos * model.parameters.nb_data, nb_var_pos * (model.parameters.nb_data - 1));
	mat CSx = kron(ones(model.parameters.nb_data, 1), eye(nb_var_pos, nb_var_pos * 2));

	// M is filled from its last columns towards its first ones: at each step
	// a new block (the previous leading block multiplied by A) is prepended
	mat M = zeros(B.n_rows, 2 * model.parameters.nb_data);

	const int last_block = 2 * model.parameters.nb_data - 2;
	M(span::all, span(last_block, last_block + 1)) = B;

	for (int n = 1; n < model.parameters.nb_data; ++n) {
		span id1(n * nb_var_pos, (n + 1) * nb_var_pos - 1);
		span id2(0, n * nb_var_pos - 1);
		int n2 = 2 * model.parameters.nb_data - n * 2;

		// The rows of time step n become the rows of step n - 1 times A
		// (rows n are only updated once rows n - 1 are final)
		CSx.rows(id1) = CSx.rows(id1) * A;
		CSu(id1, id2) = C * M(span::all, span(n2, n2 + n * 2 - 1));

		M(span::all, span(n2 - 2, n2 - 1)) = A * M(span::all, span(n2, n2 + 1));
	}

	// Create single Gaussian N(MuQ,SigmaQ) based on optimal state sequence q, see Eq. (27)
	urowvec qList = index_max(h, 0);
	mat MuQ(nb_var_pos, qList.n_elem);
	mat sigma_(nb_var_pos, nb_var_pos * qList.n_elem);

	for (uword i = 0; i < qList.n_elem; ++i) {
		MuQ(span::all, i) = model.mu[qList(i)];
		sigma_(span::all, span(i * nb_var_pos, (i + 1) * nb_var_pos - 1)) = model.sigma[qList(i)];
	}

	// Stack the centers into a single column vector
	MuQ = reshape(MuQ, MuQ.n_elem, 1);

	// Block-diagonal matrix made of the covariances of the selected states
	mat SigmaQ = (kron(ones(model.parameters.nb_data, 1), eye(nb_var_pos, nb_var_pos)) * sigma_) %
				 kron(eye(model.parameters.nb_data, model.parameters.nb_data), ones(nb_var_pos, nb_var_pos));

	// Set matrices to compute the damped weighted least squares estimate
	mat CSuInvSigmaQ = (pinv(mat(SigmaQ.t())) * CSu).t();
	mat Rq = CSuInvSigmaQ * CSu + R;

	// Reproductions (initial state: start position, zero velocity)
	vec X = zeros(nb_var_pos * 2);
	X(span(0, nb_var_pos - 1)) = start_point;

	mat rq = CSuInvSigmaQ * (MuQ - CSx * X);
	mat u = pinv(Rq) * rq;

	return reshape(CSx * X + CSu * u, nb_var_pos, model.parameters.nb_data);
}


/****************************** HELPER FUNCTIONS *****************************/

// Error handler registered with GLFW: prints the error code and its
// description on the standard error stream
static void error_callback(int error, const char* description)
{
	fprintf(
		stderr,
		"Error %d: %s\n",
		error, description
	);
}


//-----------------------------------------------------------------------------
// Palette used to draw the gaussians: one color per row, three components
// per color. The rows are selected cyclically when there are more states
// than colors.
//-----------------------------------------------------------------------------
const mat COLORS({
	{ 0.00, 0.00, 1.00 },
	{ 0.00, 0.50, 0.00 },
	{ 1.00, 0.00, 0.00 },
	{ 0.00, 0.75, 0.75 },
	{ 0.75, 0.00, 0.75 },
	{ 0.75, 0.75, 0.00 },
	{ 0.25, 0.25, 0.25 }
});


//-----------------------------------------------------------------------------
// Create a demonstration of 'nb_data' datapoints from a trajectory (of any
// length), by linear interpolation over the index of its points.
//
// Inputs:
//   - trajectory: list of points; only the first two coordinates of each
//                 point are used. The callers in this file only pass
//                 trajectories with more than one point.
//   - nb_data:    number of datapoints of the demonstration
//
// Output:
//   - 2 x nb_data matrix (first row: x, second row: y)
//-----------------------------------------------------------------------------
mat sample_trajectory(const vector_list_t& trajectory, int nb_data) {

	const size_t nb_points = trajectory.size();

	// Split the trajectory into one series per coordinate
	vec xs(nb_points);
	vec ys(nb_points);

	for (size_t k = 0; k < nb_points; ++k) {
		xs(k) = trajectory[k](0);
		ys(k) = trajectory[k](1);
	}

	// Positions of the original points and of the resampled ones, both
	// expressed as (fractional) indices in the original trajectory
	const vec source_positions = linspace<vec>(0, nb_points - 1, nb_points);
	const vec target_positions = linspace<vec>(0, nb_points - 1, nb_data);

	// Resampling of the trajectory
	vec xs_resampled;
	vec ys_resampled;

	interp1(source_positions, xs, target_positions, xs_resampled, "*linear");
	interp1(source_positions, ys, target_positions, ys_resampled, "*linear");

	// Create the demonstration
	mat demo(2, nb_data);
	demo.row(0) = xs_resampled.t();
	demo.row(1) = ys_resampled.t();

	return demo;
}


//-----------------------------------------------------------------------------
// State of the application: what the user is currently doing, which updates
// are pending, and the values of the parameters as currently displayed in
// the UI
//-----------------------------------------------------------------------------
struct gui_state_t {

	// True while the user is drawing a new demonstration
	bool is_drawing_demonstration;

	// True when the parameters were modified through the UI
	bool are_parameters_modified;

	// True when the reproductions must be recomputed
	bool must_recompute;

	// Values edited through the UI; they mirror the fields of parameters_t
	// with the corresponding names
	int   parameter_nb_data;
	int   parameter_nb_states;
	float parameter_rfactor;
};


/******************************* MAIN FUNCTION *******************************/

int main(int argc, char **argv){
	arma_rng::set_seed_random();

	// Model
	model_t model;

	// Parameters
	model.parameters.nb_data   = 200;
	model.parameters.nb_states = 6;
	model.parameters.rfactor   = 1e-2;
	model.parameters.dt        = 0.01;


	// Take 4k screens into account (framebuffer size != window size)
	gfx2::window_size_t window_size;
	window_size.win_width = 1200;
	window_size.win_height = 600;
	window_size.fb_width = -1;	// Will be known later
	window_size.fb_height = -1;


	// Initialise GLFW
	glfwSetErrorCallback(error_callback);

	if (!glfwInit())
		return -1;

	glfwWindowHint(GLFW_SAMPLES, 4);
	glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 2);
	glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 1);

	GLFWwindow* window = gfx2::create_window_at_optimal_size(
		"Demo HSMM batch", window_size.win_width, window_size.win_height
	);

	glfwMakeContextCurrent(window);


	// Setup OpenGL
	gfx2::init();
	glEnable(GL_DEPTH_TEST);
	glEnable(GL_CULL_FACE);
	glEnable(GL_LINE_SMOOTH);
	glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);


	// Setup ImGui
	ImGui::CreateContext();
	ImGui_ImplGlfwGL2_Init(window, true);


	// GUI state
	gui_state_t gui_state;
	gui_state.is_drawing_demonstration = false;
	gui_state.are_parameters_modified = true;
	gui_state.must_recompute = false;
	gui_state.parameter_nb_data = model.parameters.nb_data;
	gui_state.parameter_nb_states = model.parameters.nb_states;
	gui_state.parameter_rfactor = model.parameters.rfactor;


	// Main loop
	vector_list_t current_trajectory;
	std::vector<vector_list_t> original_trajectories;
	matrix_list_t demonstrations;
	mat reproduction;

	while (!glfwWindowShouldClose(window)){
		glfwPollEvents();

		// Handling of the resizing of the window
		gfx2::window_size_t previous_size;

		gfx2::window_result_t window_result =
			gfx2::handle_window_resizing(window, &window_size, &previous_size);

		if (window_result == gfx2::INVALID_SIZE)
			continue;

		if (window_result == gfx2::WINDOW_RESIZED) {

			// Rescale the demonstrations so they stay in the window
			float scale_x = (float) window_size.fb_width / previous_size.fb_width;
			float scale_y = (float) window_size.fb_height / previous_size.fb_height;

			for (size_t i = 0; i < original_trajectories.size(); ++i) {
				for (size_t j = 0; j < original_trajectories[i].size(); ++j) {
					original_trajectories[i][j](0) *= scale_x;
					original_trajectories[i][j](1) *= scale_y;
				}
			}

			gui_state.are_parameters_modified = true;
		}


		// If the parameters changed, learn the model again
		if (gui_state.are_parameters_modified) {

			if (!demonstrations.empty() && (demonstrations[0].n_cols != gui_state.parameter_nb_data)) {
				demonstrations.clear();

				for (size_t i = 0; i < original_trajectories.size(); ++i) {
					mat sampled_trajectory = sample_trajectory(original_trajectories[i],
															   gui_state.parameter_nb_data);
					sampled_trajectory.row(0) /= window_size.fb_width;
					sampled_trajectory.row(1) /= window_size.fb_height;

					demonstrations.push_back(sampled_trajectory);
				}
			}

			model.parameters.nb_data = gui_state.parameter_nb_data;
			model.parameters.nb_states = gui_state.parameter_nb_states;
			model.parameters.rfactor = gui_state.parameter_rfactor;

			gui_state.are_parameters_modified = false;
			gui_state.must_recompute = !demonstrations.empty();
		}

		if (!demonstrations.empty() && gui_state.must_recompute) {
			learn(demonstrations, model);

			mat all_start_points(2, demonstrations.size());
			for (int i = 0; i < demonstrations.size(); ++i)
				all_start_points(span::all, i) = demonstrations[i](span::all, 0);

			reproduction = compute_LQR(model, mean(all_start_points, 1));

			gui_state.must_recompute = false;
		}


		// Start the rendering
		ImGui_ImplGlfwGL2_NewFrame();

		glViewport(0, 0, window_size.fb_width, window_size.fb_height);
		glClearColor(1.0f, 1.0f, 1.0f, 1.0f);
		glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

		glMatrixMode( GL_PROJECTION );
		glLoadIdentity();
		glOrtho(0, window_size.fb_width, 0,  window_size.fb_height, -1., 1.);
		glMatrixMode( GL_MODELVIEW );
		glLoadIdentity();

		glPushMatrix();

		// Draw the GMM states
		if (!model.mu.empty()) {
			for (int i = 0; i < model.parameters.nb_states; ++i) {
				glClear(GL_DEPTH_BUFFER_BIT);

				vec mu(2);
				mu(0) = model.mu[i](0) * window_size.fb_width;
				mu(1) = model.mu[i](1) * window_size.fb_height;

				mat scaling({
					{ (double) window_size.fb_width, 0.0 },
					{ 0.0, (double) window_size.fb_height }
				});

				gfx2::draw_gaussian(
					conv_to<fvec>::from(COLORS.row(i % COLORS.n_rows).t()), mu,
					scaling * model.sigma[i] * scaling.t()
				);
			}

			glClear(GL_DEPTH_BUFFER_BIT);
		}

		// Draw the currently created demonstration (if any)
		if (current_trajectory.size() > 1)
			gfx2::draw_line(fvec({0.0f, 0.0f, 0.0f}), current_trajectory);

		// Draw the demonstrations
		for (size_t i = 0; i < demonstrations.size(); ++i) {
			mat datapoints = demonstrations[i];
			datapoints.row(0) *= window_size.fb_width;
			datapoints.row(1) *= window_size.fb_height;

			gfx2::draw_line(fvec({0.3f, 0.3f, 0.3f}), datapoints);
		}

		// Draw the reproduction
		if (!demonstrations.empty()) {
			mat scaled_reproduction = reproduction;
			scaled_reproduction.row(0) *= window_size.fb_width;
			scaled_reproduction.row(1) *= window_size.fb_height;

			glLineWidth(4.0f);
			gfx2::draw_line(fvec({1.0f, 0.0f, 0.0f}), scaled_reproduction);
			glLineWidth(1.0f);
		}

		glPopMatrix();


		// Control panel GUI
		ImGui::SetNextWindowPos(ImVec2(2,2));
		ImGui::SetNextWindowSize(ImVec2(500, 126));

		ImGui::Begin("Control Panel", NULL,
					 ImGuiWindowFlags_NoTitleBar|ImGuiWindowFlags_NoResize|
					 ImGuiWindowFlags_NoMove|ImGuiWindowFlags_NoSavedSettings
		);

		ImGui::Text("Left-click to collect demonstrations");
		ImGui::SliderInt("Nb states", &gui_state.parameter_nb_states, 2, 20);
		ImGui::SliderInt("Nb data", &gui_state.parameter_nb_data, 100, 300);
		ImGui::SliderFloat("LQR control cost", &gui_state.parameter_rfactor, 1e-3, 1e-1);

		if (ImGui::Button("Apply"))
			gui_state.are_parameters_modified = true;

		ImGui::SameLine();

		if (ImGui::Button("Clear")) {
			demonstrations.clear();
			original_trajectories.clear();
			model.mu.clear();
			model.sigma.clear();
			model.transitions.clear();
			model.states_priors.clear();
			model.H.clear();
		}

		ImGui::End();


		// GUI rendering
		ImGui::Render();
		ImGui_ImplGlfwGL2_RenderDrawData(ImGui::GetDrawData());

		// Swap buffers
		glfwSwapBuffers(window);

		// Keyboard input
		if (ImGui::IsKeyPressed(GLFW_KEY_ESCAPE))
			break;


		if (!gui_state.is_drawing_demonstration) {
			// Left click: start a new demonstration (only if not on the UI and in the
			// demonstrations viewport)
			if (ImGui::IsMouseClicked(GLFW_MOUSE_BUTTON_1) && !ImGui::GetIO().WantCaptureMouse) {
				double mouse_x, mouse_y;
				glfwGetCursorPos(window, &mouse_x, &mouse_y);

				gui_state.is_drawing_demonstration = true;

				vec coords = gfx2::ui2fb({ mouse_x, mouse_y }, window_size);
				current_trajectory.push_back(coords);
			}
		} else if (gui_state.is_drawing_demonstration) {
			double mouse_x, mouse_y;
			glfwGetCursorPos(window, &mouse_x, &mouse_y);

			vec coords = gfx2::ui2fb({ mouse_x, mouse_y }, window_size);

			vec last_point = current_trajectory[current_trajectory.size() - 1];
			vec diff = abs(coords - last_point);

			if ((diff(0) > 1e-6) && (diff(1) > 1e-6))
				current_trajectory.push_back(coords);

			// Left mouse button release: end the demonstration creation
			if (!ImGui::IsMouseDown(GLFW_MOUSE_BUTTON_1)) {
				gui_state.is_drawing_demonstration = false;

				if (current_trajectory.size() > 1) {
					mat sampled_trajectory = sample_trajectory(
						current_trajectory, gui_state.parameter_nb_data
					);

					sampled_trajectory.row(0) /= window_size.fb_width;
					sampled_trajectory.row(1) /= window_size.fb_height;

					demonstrations.push_back(sampled_trajectory);

					original_trajectories.push_back(current_trajectory);

					gui_state.must_recompute = true;
				}

				current_trajectory.clear();
			}
		}
	}

	// Cleanup
	ImGui_ImplGlfwGL2_Shutdown();
	glfwTerminate();

	return 0;
}
